"""
Core regression library with model definitions.
The tanh-transformed distribution is adapted from the one used for soft
actor-critic (SAC) in DeepMind's Acme:
https://github.com/deepmind/acme/blob/master/acme/jax/networks/distributional.py.
"""

import functools
import math
from typing import Any, Callable, Iterable, NamedTuple, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability.substrates.jax as tfp
from tqdm import tqdm

# Short aliases for the TFP distributions and Haiku initializers namespaces.
tfd = tfp.distributions
hk_init = hk.initializers


class TanhTransformedDistribution(tfd.TransformedDistribution):
    """Distribution followed by tanh.

    Wraps a base distribution with a fixed ``tfp.bijectors.Tanh`` bijector.
    ``log_prob`` clips events to ``[-threshold, threshold]`` (the comment in
    ``log_prob`` notes NaNs otherwise) and, at the clip boundaries, returns a
    log-probability derived from the base distribution's mass beyond
    ``atanh(threshold)`` instead of the transformed density.
    """

    def __init__(self, distribution, threshold=0.999, validate_args=False):
        """Initialize the distribution.
        Args:
          distribution: The distribution to transform.
          threshold: Clipping value of the action when computing the logprob.
          validate_args: Passed to super class.
        """
        super().__init__(
            distribution=distribution,
            bijector=tfp.bijectors.Tanh(),
            validate_args=validate_args,
        )
        # Computes the log of the average probability distribution outside the
        # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for
        # log_prob_left and [atanh(threshold), inf] for log_prob_right.
        self._threshold = threshold
        inverse_threshold = self.bijector.inverse(threshold)
        # average(pdf) = p/epsilon
        # So log(average(pdf)) = log(p) - log(epsilon)
        # Here epsilon = 1 - threshold is the width of the clipped interval in
        # event space, i.e. [threshold, 1] (and symmetrically on the left).
        log_epsilon = jnp.log(1.0 - threshold)
        # Those 2 values are differentiable w.r.t. model parameters, such that the
        # gradient is defined everywhere.
        self._log_prob_left = (
            self.distribution.log_cdf(-inverse_threshold) - log_epsilon
        )
        self._log_prob_right = (
            self.distribution.log_survival_function(inverse_threshold) - log_epsilon
        )

    def log_prob(self, event):
        """Log-density of ``event`` with boundary handling.

        Events are clipped to ``[-threshold, threshold]``. Values at the lower
        or upper clip boundary get the precomputed left/right tail
        log-probabilities; all others use the transformed log-density.
        """
        # Without this clip there would be NaNs in the inner tf.where and that
        # causes issues for some reasons.
        event = jnp.clip(event, -self._threshold, self._threshold)
        # The inverse image of {threshold} is the interval [atanh(threshold), inf]
        # which has a probability of "log_prob_right" under the given distribution.
        return jnp.where(
            event <= -self._threshold,
            self._log_prob_left,
            jnp.where(
                event >= self._threshold, self._log_prob_right, super().log_prob(event)
            ),
        )

    def variance(self):
        """Delta-method approximation of the variance.

        Returns ``Var[X] * (1 - tanh(m)**2)**2``, where ``X`` is the base
        variable, ``m`` its mode and ``1 - tanh(m)**2`` the derivative of tanh
        at ``m``.
        """
        # self.mode() == tanh(m), so this is tanh'(m).
        deriv = 1.0 - self.mode() ** 2
        return self.distribution.variance() * deriv**2

    def mode(self):
        """Return tanh applied to the base distribution's mode.

        NOTE: for a non-linear bijector this is the image of the base mode,
        which need not coincide with the mode of the squashed density.
        """
        return self.bijector.forward(self.distribution.mode())

    def entropy(self, seed=None):
        """Single-sample estimate of the entropy of the squashed variable.

        Uses ``H[tanh(X)] = H[X] + E[log |tanh'(X)|]`` with the expectation
        replaced by the log-det-Jacobian at one sample of the base distribution
        drawn with ``seed``.
        """
        # We return an estimation using a single sample of the log_det_jacobian.
        # We can still do some backpropagation with this estimate.
        return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
            self.distribution.sample(seed=seed), event_ndims=0
        )

    @classmethod
    def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
        """Parent parameter properties without the ``bijector`` entry.

        The bijector is fixed to Tanh in ``__init__`` rather than passed in by
        callers, so it is not exposed as a parameter.
        """
        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
        del td_properties["bijector"]
        return td_properties


class NormalTanhDistribution(hk.Module):
    """Module that produces a TanhTransformedDistribution distribution.

    Two linear heads map the input features to the location and the scale of
    a diagonal Normal, which is then squashed by tanh. ``__call__`` returns a
    pair ``(dist, cut_dist)``:

    * ``dist`` is the full distribution.
    * ``cut_dist`` is ``dist`` itself when ``faithful`` is False. When
      ``faithful`` is True it is built from a stop-gradient location, so a
      likelihood evaluated on it only trains the scale head. In that mode the
      scale head also receives stop-gradient inputs, so it cannot reshape the
      shared features.
    """

    def __init__(
        self,
        num_dimensions: int,
        faithful: bool,
        min_scale: float = 1e-3,
        w_init=hk.initializers.VarianceScaling(0.1, "fan_in", "uniform"),
        b_init: hk_init.Initializer = hk_init.Constant(0.0),
    ):
        """Initialization.
        Args:
          num_dimensions: Number of dimensions of a distribution.
          faithful: If True, detach the location in ``cut_dist`` and the inputs
            of the scale head (see the class docstring).
          min_scale: Lower bound applied to the standard deviation. The learned
            standard deviation already lies in [exp(-5), 1] (about
            [6.7e-3, 1]), so values at or below exp(-5), including the default,
            have no effect.
          w_init: Initialization for linear layer weights.
          b_init: Initialization for linear layer biases.
        """
        super().__init__(name="Normal")
        self._min_scale = min_scale
        self.faithful = faithful
        # Construction order fixes the Haiku names ("linear", "linear_1").
        self._loc_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)
        self._scale_layer = hk.Linear(num_dimensions, w_init=w_init, b_init=b_init)

    def __call__(
        self, inputs: jnp.ndarray
    ) -> Tuple[tfd.Distribution, tfd.Distribution]:
        """Build the squashed Normal for ``inputs``.

        Args:
          inputs: Features; the last axis is consumed by the linear heads.

        Returns:
          ``(dist, cut_dist)`` as described in the class docstring.
        """
        loc = self._loc_layer(inputs)
        # In faithful mode the scale head sees detached features, so the
        # gradient through the scale never reaches the shared trunk.
        scale_inputs = jax.lax.stop_gradient(inputs) if self.faithful else inputs
        raw_scale = self._scale_layer(scale_inputs)
        # Squash the raw output so that log(scale) lies in [-5, 0].
        log_scale = -5 + 5 * (0.5 + 0.5 * jnp.tanh(raw_scale))
        # Honour min_scale; this is a no-op whenever min_scale <= exp(-5).
        scale = jnp.maximum(jnp.exp(log_scale), self._min_scale)
        dist = tfd.Independent(
            TanhTransformedDistribution(tfd.Normal(loc=loc, scale=scale)),
            reinterpreted_batch_ndims=1,
        )
        if not self.faithful:
            return dist, dist
        # Detach the location so a likelihood on cut_dist only trains the scale.
        cut_dist = tfd.Independent(
            TanhTransformedDistribution(
                tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
            ),
            reinterpreted_batch_ndims=1,
        )
        return dist, cut_dist


def make_sac_mlp(num_dimensions, hidden_layer_sizes, faithful=True):
    """Build an ELU MLP trunk followed by a NormalTanhDistribution head.

    Args:
      num_dimensions: Output dimensionality of the distribution.
      hidden_layer_sizes: Sizes of the MLP layers; the last layer is activated.
      faithful: Forwarded to NormalTanhDistribution.

    Returns:
      A function ``obs -> (dist, cut_dist)``. It creates Haiku modules, so it
      must be called inside a Haiku transform.
    """

    def sac_mlp(obs):
        trunk = hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(0.1, "fan_in", "uniform"),
            activation=jax.nn.elu,
            activate_final=True,
        )
        head = NormalTanhDistribution(num_dimensions, faithful=faithful)
        return hk.Sequential([trunk, head])(obs)

    return sac_mlp


# Periodic activation functions used by stationary approximations.


@jax.jit
def _triangle_activation(x: jnp.ndarray) -> jnp.ndarray:
    """Unnormalised triangle wave with period 2*pi and amplitude pi/2.

    Follows ``x`` on [-pi/2, pi/2) and reflects on every further half period,
    giving a piecewise-linear analogue of ``(pi/2) * sin(x)``.
    """
    # Index of the half period containing x (half periods centred on k*pi).
    half_period = jnp.floor(x / jnp.pi + 0.5)
    # +1 on even half periods, -1 on odd ones (jnp.mod takes the divisor's
    # sign, so negative indices are handled too).
    sign = 1 - 2 * jnp.mod(half_period, 2)
    return (x - jnp.pi * half_period) * sign


@jax.jit
def triangle_activation(x: jnp.ndarray) -> jnp.ndarray:
    """Triangle wave with period 2*pi, normalised to the range [-1, 1].

    Sine-like: zero at 0 and reaching +1 at pi/2.
    """
    # _triangle_activation has amplitude pi/2. Use the exact constant rather
    # than a truncated literal; it is identical in float32 and exact under x64.
    return _triangle_activation(x) / (jnp.pi / 2)


@jax.jit
def periodic_relu_activation(x: jnp.ndarray) -> jnp.ndarray:
    """Periodic ReLU: two triangle waves a quarter period apart, summed.

    The sum of ``_triangle_activation(x)`` and its pi/2-shifted copy is a
    trapezoidal wave; it is scaled by pi/4.
    """
    # Exact pi fractions instead of truncated literals (same values in float32).
    return (_triangle_activation(x) + _triangle_activation(x + jnp.pi / 2)) * (
        jnp.pi / 4
    )


@jax.jit
def sin_cos_activation(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise ``sin(x) + cos(x)`` (equal to ``sqrt(2) * sin(x + pi/4)``)."""
    cosine = jnp.cos(x)
    sine = jnp.sin(x)
    return cosine + sine


@jax.jit
def hard_sin(x: jnp.ndarray) -> jnp.ndarray:
    """Sine-phase periodic activation: ``periodic_relu_activation`` shifted right by pi/4."""
    # pi/4 offset; exact constant instead of a truncated literal.
    return periodic_relu_activation(x - jnp.pi / 4)


@jax.jit
def hard_cos(x: jnp.ndarray) -> jnp.ndarray:
    """Cosine-phase periodic activation: ``periodic_relu_activation`` shifted left by pi/4."""
    # pi/4 offset; exact constant instead of a truncated literal.
    return periodic_relu_activation(x + jnp.pi / 4)


class StationaryNormalTanhDistribution(hk.Module):
    """Module that produces a TanhTransformedDistribution from periodic features.

    Inputs are projected by a learned matrix ``rweights`` (initialised with
    ``hk.initializers.RandomNormal()``). The projection is passed through
    triangle waves to form ``num_features`` periodic features ``f``.

    * The location is ``f @ mean``.
    * The scale of output dimension ``k`` is ``||cov_sqrt[k] f||``, i.e.
      ``sqrt(f^T cov[k] f)`` with ``cov[k] = cov_sqrt[k]^T cov_sqrt[k]``.

    ``__call__`` returns ``(dist, cut_dist)``. When ``faithful`` is True,
    ``cut_dist`` uses a stop-gradient location and the scale is computed from
    stop-gradient features. Otherwise both entries are the same distribution.
    """

    def __init__(self, num_dimensions: int, num_features: int, faithful: bool):
        """Initialization.
        Args:
          num_dimensions: Number of dimensions of the distribution.
          num_features: Number of periodic features. Must be a positive even
            integer because features are built in pairs from
            ``num_features // 2`` projections.
          faithful: Whether ``cut_dist`` detaches the location and the scale
            uses detached features (see the class docstring).

        Raises:
          ValueError: If ``num_features`` is not a positive even integer (an odd
            value would otherwise fail later with a shape mismatch).
        """
        if num_features <= 0 or num_features % 2 != 0:
            raise ValueError(
                f"num_features must be a positive even integer, got {num_features}."
            )
        super().__init__(name="HetStat")
        self.num_features = num_features
        self.num_dimensions = num_dimensions
        self.faithful = faithful

    def features(self, x):
        """Periodic features of ``x``, scaled by ``1 / sqrt(num_features)``.

        The last axis of the result has size ``num_features``. Its two halves
        are the same projection passed through triangle waves whose phases
        differ by a quarter period.
        """
        d = x.shape[-1]
        w = hk.get_parameter(
            "rweights", shape=[d, self.num_features // 2], init=hk_init.RandomNormal()
        )
        ph = x @ w
        # Offsets of -pi/4 and +pi/4 put the two waves a quarter period apart,
        # like hard_sin/hard_cos. A half-period separation would be useless:
        # the triangle wave satisfies tri(t + pi) = -tri(t), so the two halves
        # would be exact negatives and half of the features redundant.
        shift = jnp.pi / 4
        sf = 1 / math.sqrt(self.num_features)
        return sf * jnp.concatenate(
            (
                triangle_activation(ph - shift),
                triangle_activation(ph + shift),
            ),
            axis=-1,
        )

    def __call__(
        self, inputs: jnp.ndarray
    ) -> Tuple[tfd.Distribution, tfd.Distribution]:
        """Build the squashed Normal for a batch of ``inputs``.

        Args:
          inputs: Array whose last axis holds the input features. The scale
            einsum below requires ``features(inputs)`` to be 2-D
            (batch, feature).

        Returns:
          ``(dist, cut_dist)`` as described in the class docstring.
        """
        f = self.features(inputs)

        mu = hk.get_parameter(
            "mean",
            shape=[self.num_features, self.num_dimensions],
            init=hk_init.Constant(0.0),
        )
        cov_sqrt = hk.get_parameter(
            "cov_sqrt",
            shape=[self.num_dimensions, self.num_features, self.num_features],
            init=hk_init.Identity(1.5),
        )

        loc = f @ mu
        # In faithful mode the scale is computed from detached features, so it
        # cannot shape the projection or the trunk.
        f_scale = jax.lax.stop_gradient(f) if self.faithful else f

        # sqrt(f^T (A^T A) f) == ||A f||. Summing squares keeps the argument of
        # sqrt non-negative in floating point and avoids forming A^T A.
        projected = jnp.einsum("kij,bj->bki", cov_sqrt, f_scale)
        scale = jnp.sqrt(jnp.sum(projected**2, axis=-1))
        dist = tfd.Independent(
            TanhTransformedDistribution(tfd.Normal(loc=loc, scale=scale)),
            reinterpreted_batch_ndims=1,
        )
        if not self.faithful:
            return dist, dist
        # Detach the location so a likelihood on cut_dist only trains the scale.
        cut_dist = tfd.Independent(
            TanhTransformedDistribution(
                tfd.Normal(loc=jax.lax.stop_gradient(loc), scale=scale)
            ),
            reinterpreted_batch_ndims=1,
        )
        return dist, cut_dist


def make_hetstat_mlp(num_dimensions, hidden_layer_sizes, faithful=True):
    """Build an ELU MLP trunk followed by a StationaryNormalTanhDistribution head.

    The last entry of ``hidden_layer_sizes`` is the number of periodic
    features of the head. The remaining entries size the MLP trunk, whose
    final layer is left linear (``activate_final=False``).

    Args:
      num_dimensions: Output dimensionality of the distribution.
      hidden_layer_sizes: Non-empty sequence of sizes (see above).
      faithful: Forwarded to StationaryNormalTanhDistribution.

    Returns:
      A function ``obs -> (dist, cut_dist)``. It creates Haiku modules, so it
      must be called inside a Haiku transform.
    """
    num_features = hidden_layer_sizes[-1]

    def hetstat_mlp(obs):
        network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(hidden_layer_sizes[:-1]),
                    w_init=hk.initializers.VarianceScaling(0.1, "fan_in", "uniform"),
                    activation=jax.nn.elu,
                    activate_final=False,
                ),
                # Honour the caller's choice, as make_sac_mlp does.
                StationaryNormalTanhDistribution(
                    num_dimensions, num_features, faithful=faithful
                ),
            ]
        )
        return network(obs)

    return hetstat_mlp


def kl_estimator(dist_1, dist_2, key, n_samples=100):
    """Monte Carlo estimate of KL(dist_1 || dist_2).

    Draws ``n_samples`` samples from ``dist_1`` and averages
    ``dist_1.log_prob(a) - dist_2.log_prob(a)`` over every element of the
    result (sample and batch axes alike).

    Args:
      dist_1: Distribution sampled from (first KL argument).
      dist_2: Distribution evaluated at those samples (second KL argument).
      key: PRNG key passed as ``seed`` to ``dist_1.sample``.
      n_samples: Number of Monte Carlo samples.

    Returns:
      Scalar estimate of the KL divergence.
    """
    draws = dist_1.sample(seed=key, sample_shape=n_samples)
    return jnp.mean(dist_1.log_prob(draws) - dist_2.log_prob(draws))


class State(NamedTuple):
    """Training state carried through the jitted update step of ``train_model``."""

    # Optimizer state returned by the optax optimizer's ``init``/``update``.
    opt_state: optax.OptState
    # Current model parameters.
    params: hk.Params


def train_model(
    network,
    x,
    y,
    faithful=True,
    refine=False,
    initial_param=None,
    n_iters=1000,
    learning_rate=1e-4,
    seed=0,
):
    """Fit ``network`` to ``(x, y)`` with Adam and return the trained policy.

    Args:
      network: Object exposing ``init(key, x)`` and ``apply(params, x)``, where
        ``apply`` returns ``(dist, cut_dist)`` (presumably an ``hk.transform``
        of one of the ``make_*_mlp`` functions).
      x: Inputs. Indexed as ``x[0, None, :]`` for initialisation, so at least 2-D.
      y: Regression targets.
      faithful: If True, minimise ``MSE(dist.mode(), y) + NLL(cut_dist, y)``.
        Otherwise minimise ``NLL(dist, y)``.
      refine: If True, add a Monte Carlo KL penalty between the model's
        distribution at 150 inputs drawn uniformly from [-20, 20] and a
        standard Normal prior with one event dimension.
      initial_param: Parameters to start from. ``network.init`` is used when None.
      n_iters: Number of optimisation steps.
      learning_rate: Adam learning rate.
      seed: Seed of the PRNG key used for initialisation and the KL penalty.

    Returns:
      ``(policy, params)``: ``policy(x)`` returns the first distribution of
      ``network.apply`` under the trained parameters ``params``.
    """
    key = jax.random.PRNGKey(seed)

    if initial_param is not None:
        param = initial_param
    else:
        param = network.init(key, x[0, None, :])

    opt = optax.adam(learning_rate)
    state = State(opt.init(param), param)

    # Standard Normal prior with one event dimension (only used when refine).
    prior_dist = tfd.Independent(
        tfd.Normal(loc=jnp.zeros((1,)), scale=1 * jnp.ones((1,))),
        reinterpreted_batch_ndims=1,
    )
    input_dim = x.shape[-1]

    def refine_penalty(params, key):
        """KL(model || prior) estimated at random inputs in [-20, 20]."""
        # Separate keys for input sampling and KL sampling; reusing one key for
        # both would correlate the two sets of random draws.
        x_key, kl_key = jax.random.split(key)
        x_samples = jax.random.uniform(
            x_key, shape=(150, input_dim), minval=-20.0, maxval=20.0
        )
        dist, _ = network.apply(params, x_samples)
        return kl_estimator(dist, prior_dist, kl_key)

    def loss_fn(params, x, y, key):
        dist, cut_dist = network.apply(params, x)
        if faithful:
            # The mode is fit by MSE. The NLL on cut_dist (whose location the
            # network may have gradient-detached) fits the spread.
            mse = ((y - dist.mode()) ** 2).mean()
            nllh = -cut_dist.log_prob(y).mean()
            loss = mse + nllh
        else:
            loss = -dist.log_prob(y).mean()
        if refine:
            loss = loss + refine_penalty(params, key)
        return loss

    @jax.jit
    def step(state, x, y, key):
        loss_value, grad = jax.value_and_grad(loss_fn)(state.params, x, y, key)
        updates, opt_state = opt.update(grad, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        return State(opt_state, params), loss_value

    for _ in tqdm(range(n_iters)):
        # Carry `key` forward and hand the step a fresh subkey, so no key is
        # both consumed and split.
        key, step_key = jax.random.split(key)
        state, _ = step(state, x, y, step_key)

    def policy(x):
        return network.apply(state.params, x)[0]

    return policy, state.params
